from collections import OrderedDict

import torch
import numpy as np
from torch.nn.utils import spectral_norm


class ResBlock(torch.nn.Module):
    """Strict residual MLP block computing ``h(x) = x + g(x)``.

    ``g`` is ``Linear -> LeakyReLU -> Linear``. Both linear layers are
    square (``num_inputs -> num_inputs``) and wrapped with ``spectral_norm``.
    Because the layers are square, the identity shortcut needs no projection.
    No activation is applied after the sum, which keeps the block "strict".

    Args:
        num_inputs: feature width of the input, hidden and output layers.
    """

    def __init__(self, num_inputs):
        super().__init__()
        self.num_inputs = num_inputs

        width = num_inputs
        # Both layers are created before either is spectrally normalised.
        # This keeps the random-initialisation order stable.
        first, second = (
            torch.nn.Linear(in_features=width, out_features=width)
            for _ in range(2)
        )

        layers = [
            spectral_norm(first),
            torch.nn.LeakyReLU(),
            spectral_norm(second),
        ]
        self.model = torch.nn.Sequential(*layers)
        # Placeholder attribute; this class never assigns it after __init__.
        self.latent_feature = None

    def forward(self, x):
        """Return ``g(x) + x``, the residual update ``hl(x) = x + gl(x)``.

        This follows "Simple and Principled Uncertainty Estimation with
        Deterministic Deep Learning via Distance Awareness"
        (https://arxiv.org/pdf/2006.10108.pdf). See also
        https://github.com/omegafragger/DDU/blob/f597744c65df4ff51615ace5e86e82ffefe1cd0f/net/resnet.py

        Presumably ``x`` has ``num_inputs`` features on its last dimension;
        confirm against callers.
        """
        residual = self.model(x)
        return residual + x


# TODO: evaluate this more advanced residual block (adds BatchNorm and a third layer) as an alternative to ResBlock
class ResBlockAdvance(torch.nn.Module):
    """Residual MLP block with BatchNorm and spectrally normalised layers.

    Computes ``act(x + g(x))``. ``g`` is::

        Linear -> BatchNorm1d -> LeakyReLU -> Linear -> BatchNorm1d -> LeakyReLU -> Linear

    Every ``Linear`` is square (``num_inputs -> num_inputs``) and wrapped with
    ``spectral_norm``, so the identity shortcut needs no projection. ``act``
    is a LeakyReLU applied after the residual sum.

    Args:
        num_inputs: feature width of the input, hidden and output layers.
    """

    def __init__(self, num_inputs):
        super(ResBlockAdvance, self).__init__()
        self.num_inputs = num_inputs

        dense_layer_1 = torch.nn.Linear(in_features=self.num_inputs, out_features=self.num_inputs)
        bn_1 = torch.nn.BatchNorm1d(num_features=self.num_inputs)
        dense_layer_2 = torch.nn.Linear(in_features=self.num_inputs, out_features=self.num_inputs)
        bn_2 = torch.nn.BatchNorm1d(num_features=self.num_inputs)
        dense_layer_3 = torch.nn.Linear(in_features=self.num_inputs, out_features=self.num_inputs)
        # Activation applied after the residual sum in forward().
        self.activation = torch.nn.LeakyReLU()

        # Nonlinearities between the layers are required. Without them the
        # Linear/BatchNorm stack collapses into a single affine map, and the
        # block is no better than one Linear layer. This mirrors the
        # layer -> BN -> ReLU pattern of the referenced DDU resnet.
        # NOTE: inserting the activations shifts the module indices inside
        # ``self.model``. state_dicts saved with the old layout (if any)
        # need their keys remapped before loading.
        self.model = torch.nn.Sequential(
            spectral_norm(dense_layer_1),
            bn_1,
            torch.nn.LeakyReLU(),
            spectral_norm(dense_layer_2),
            bn_2,
            torch.nn.LeakyReLU(),
            spectral_norm(dense_layer_3))

    def forward(self, x):
        """Return ``act(g(x) + x)``, the residual update ``hl(x) = x + gl(x)``.

        This follows "Simple and Principled Uncertainty Estimation with
        Deterministic Deep Learning via Distance Awareness"
        (https://arxiv.org/pdf/2006.10108.pdf). See also
        https://github.com/omegafragger/DDU/blob/f597744c65df4ff51615ace5e86e82ffefe1cd0f/net/resnet.py

        Presumably ``x`` is a 2-D ``(batch, num_inputs)`` tensor, since
        BatchNorm1d normalises dim 1 while Linear acts on the last dim;
        confirm against callers. In training mode, BatchNorm1d needs more
        than one sample per batch.
        """
        return self.activation(self.model(x) + x)

# tmp = ResBlockAdvance(32)
# abc = tmp(torch.ones([2, 32]))
# print(abc)
